"""
部分旋转 [:,:,20]
对 loss的研究，extraction_rate, c_init,c_fin
"""
import argparse
from collections import defaultdict

from torch import optim
from torch.utils.data import DataLoader
import wandb
from tqdm import trange

from disvae.models.anneal import *
from disvae.models.losses import BtcvaeLoss, get_loss_s, _reconstruction_loss, _kl_normal_loss
from disvae.models import encoders, decoders
from disvae.models.vae import reparameterize
from exps.models import get_preds
from exps.visualize import *
from exps.patterns import data_generator
import numpy as np
from utils.helpers import set_seed

# Default hyperparameters; each key becomes a CLI flag of the same name below.
hyperparameter_defaults = dict(
    batch_size=128,
    learning_rate=5e-4,
    epochs=841,
    beta=10,            # weight on the KL term of the beta-VAE style loss
    dimension=3,        # number of latent dimensions (one decoder per latent)
    img_id=0,           # which base pattern to load: patterns/{img_id}.pat
    random_seed=210,
    file='movement.py'
)

# Build one CLI argument per default. The flag's parse type is inferred from
# the default's Python type (e.g. --beta parses as int because 10 is an int,
# so fractional beta values cannot be passed without changing the default).
parser = argparse.ArgumentParser()
for key, value in hyperparameter_defaults.items():
    parser.add_argument(f'--{key}', default=value, type=type(value))
args = parser.parse_args()

# Register the run with wandb under the 'movement' group. `config` is read
# as a module-level global by train() and the setup code below.
wandb.init(project="exps", config=args, group='movement')
config = args
set_seed(config.random_seed)


def train(dl, model, epochs):
    """Train the shared encoder plus the per-latent decoders on `dl`.

    Relies on module-level globals `encoder`, `decoder_list`, `config` and
    `wandb`; `model` is passed in only so the optimizer covers the union of
    all encoder/decoder parameters.

    Parameters
    ----------
    dl : DataLoader yielding (img, label) batches of 64x64 images.
    model : nn.Module container holding the encoder and all decoders.
    epochs : int, number of full passes over `dl`.

    Returns
    -------
    dict
        The last epoch's `storer`: per-key epoch means plus, on logging
        epochs, wandb media objects.
    """
    opt = optim.AdamW(model.parameters(), config.learning_rate)

    for e in trange(epochs):
        storer = defaultdict(list)
        for img, _ in dl:  # labels unused during optimization
            img = img.view(-1, 1, 64, 64).cuda()
            latent_dist = encoder(img)
            latent_sample = reparameterize(*latent_dist)

            # beta-weighted KL on the full latent distribution plus one
            # reconstruction term per latent dimension (each decoder sees
            # only a single latent slice).
            loss = config.beta * _kl_normal_loss(*latent_dist, storer)
            for i, decoder in enumerate(decoder_list):
                recon = decoder(latent_sample[:, i:i + 1])
                recon_loss = _reconstruction_loss(img, recon)
                # Record the reconstruction term so it shows up in wandb
                # alongside the KL terms (previously only KL was logged).
                storer[f'recon_loss_{i}'].append(recon_loss.item())
                loss = loss + recon_loss

            opt.zero_grad()
            loss.backward()
            opt.step()

        # Collapse the per-batch lists accumulated above into epoch means.
        for k, v in storer.items():
            if isinstance(v, list):
                storer[k] = np.mean(v)

        if (e + 1) % 40 == 0:
            # Periodic qualitative logging: last reconstruction, the matching
            # input, and a latent-vs-ground-truth plot for the three latents
            # with the largest KL (i.e. the most-used dimensions).
            storer['recon'] = wandb.Image(recon[0, 0])
            storer['img'] = wandb.Image(img[0, 0])

            kl_loss = torch.Tensor([v for k, v in storer.items() if 'kl_' in k])
            _, index = kl_loss.sort(descending=True)

            preds, targets = get_preds(model, dl)
            fig = gt_vs_latent(preds[:, index[:3]].cpu(), targets.cpu())
            storer['correlated'] = wandb.Image(fig)

        storer['epoch'] = e
        wandb.log(storer, sync=True)
        plt.close()
    return storer


# ---- Data ----
# Load a single 64x64 base pattern and wrap it in a Translation dataset
# (presumably translated copies of the pattern with 2-D offset labels —
# verify against exps.patterns.data_generator).
img_id = config.img_id
img = torch.load(f'patterns/{img_id}.pat')
# img = data_generator.rotate(img,45)
dataset = data_generator.Translation(img)
# dataset.imgs = dataset.o_imgs[::4,::4].reshape(-1,1,64,64)
# dataset.labels = dataset.o_labels[::4,::4].reshape(-1,2)

epochs = config.epochs
dim = config.dimension

dl = DataLoader(dataset, pin_memory=True,
                num_workers=0, batch_size=config.batch_size, shuffle=True)

# ---- Model ----
# One shared encoder producing `dim` latents, and an independent
# single-latent decoder for each latent dimension.
encoder = encoders.EncoderBurgess((1, 64, 64), dim)
decoder_list = [decoders.DecoderBurgess((1, 64, 64), 1) for _ in range(dim)]
# NOTE(review): Sequential is used here only as a parameter container so a
# single optimizer sees the encoder plus every decoder; it is never invoked
# as a forward pipeline.
model = torch.nn.Sequential(*decoder_list[1:])
model.add_module('encoder', encoder)
model.add_module('decoder', decoder_list[0])

model.train()
model.cuda()

# Train and keep the last epoch's metrics dict for the final summary log.
storer = train(dl, model, epochs)

# Rank latent dimensions by their KL contribution (higher KL => the latent
# carries more information), mirroring the ranking done inside train().
kl_loss = torch.Tensor([v for k, v in storer.items() if 'kl_' in k])
_, index = kl_loss.sort(descending=True)
print(index)
model.eval()
model.cpu()

# Log the top-3 latents of every sample as a 3-D point cloud.
preds, targets = get_preds(model, dl)
points = preds[:, index[:3]].cpu()
storer['point_cloud'] = wandb.Object3D(points.numpy())

# fig = plt_sample_traversal(None, model, 7, index, r=2)
# storer['traversal'] = wandb.Image(fig)

wandb.log(storer)
plt.show()